#include "listed.hpp"
#include "cell.h"
#include "comm.h"
#include "comm_template.hpp"
#include "esmd_conf.h"
#include "esmd_types.h"
#include "hmini.hpp"
#include "htab_cpp.hpp"
#include "listed_impl.hpp"
#include "shake.h"
#include "utils.hpp"
#include "timer.h"
#include <algorithm>

// Explicit instantiation definitions of listed_force_grid for every
// listed-interaction parameter type handled by the topology. Each type is
// paired with its own per-cell capacity (T::cell_cap). Explicit instantiation
// makes this translation unit emit the template's members, so other TUs can
// use these specializations without seeing the member definitions.
template class listed_force_grid<harmonic_bond_param, harmonic_bond_param::cell_cap>;
template class listed_force_grid<harmonic_angle_param, harmonic_angle_param::cell_cap>;
template class listed_force_grid<urey_bradly_param, urey_bradly_param::cell_cap>;
template class listed_force_grid<cosine_torsion_param, cosine_torsion_param::cell_cap>;
template class listed_force_grid<ryckaert_bellmans_param, ryckaert_bellmans_param::cell_cap>;
template class listed_force_grid<extended_ryckaert_bellmans_param, extended_ryckaert_bellmans_param::cell_cap>;
template class listed_force_grid<harmonic_improper_param, harmonic_improper_param::cell_cap>;
template class listed_force_grid<rigid_param, rigid_param::cell_cap>;
template class listed_force_grid<special_pair, special_pair::cell_cap>;
/**
 * Run the forward communication step for the topology export buffers.
 *
 * A packer and an unpacker bound to this topology_grids object are built and
 * handed to forward_comm_template() together with the process descriptor and
 * the cell grid; the template drives the actual exchange.
 *
 * @param mpp  process/communication descriptor forwarded unchanged.
 * @param grid cell grid whose cells are exchanged.
 */
void topology_grids::forward_comm_export(mpp_t *mpp, cellgrid_t *grid){
  auto pack = export_packer(*this);     // serialises this topology's export data
  auto unpack = export_unpacker(*this); // merges received data back into *this
  forward_comm_template(mpp, grid, pack, unpack);
}

// Timer wrapping the whole of topology_grids::update().
DEF_TIMER (UPDATE, "topology update");
// Sunway (__sw_64__) replacements for the local export / import phases of
// topology_grids::update(); only declared here, defined elsewhere.
extern void topology_grids_export_sw (cellgrid_t *grid);
extern void topology_grids_import_sw(cellgrid_t *grid);
/**
 * Rebuild the per-cell topology data for the current atom placement.
 *
 * Export phase: for every local cell, the tags of the atoms it holds
 * (icell->tag[0..natom)) are passed to do_export(); forward_comm_export()
 * then runs the forward communication with the export packer/unpacker.
 *
 * Import phase: for every local cell
 *   - build a tag -> local-index map and call do_import_neighbors(), which
 *     also fills guest_set (presumably the tags of non-owned atoms referenced
 *     by this cell's listed interactions -- TODO confirm);
 *   - for each neighbour cell, record the exclusion / scaled special pairs
 *     whose second atom lives in that neighbour, and copy the neighbour's
 *     atoms that are in guest_set into guest slots [CELL_CAP, CELL_CAP+nguest);
 *   - store per-neighbour start offsets in first_{guest,excl,scal}_cell plus
 *     an end sentinel, so neighbour d owns the range [first[d], first[d + 1]);
 *   - relink() the topology with the tag map extended by the guest slots and
 *     rebuild the per-atom rigid-constraint records.
 *
 * On __sw_64__ builds the two local phases are delegated to
 * topology_grids_export_sw / topology_grids_import_sw.
 *
 * @param mpp  process/communication descriptor used by forward_comm_export().
 * @param grid cell grid whose local cells are rebuilt in place.
 */
void topology_grids::update(mpp_t *mpp, cellgrid_t *grid){
  timer_start(UPDATE);
#ifndef __sw_64__
  // Export phase: report the set of tags each local cell currently holds.
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    mini_hset<tagint, CELL_CAP * 16 / 3> tagset;
    for (int i = 0; i < icell->natom; i ++){
      tagset.insert(icell->tag[i]);
    }
    do_export(get_offset_xyz<true>(grid, cx, cy, cz), tagset);
  }
#else
  topology_grids_export_sw(grid);
#endif
  forward_comm_export(mpp, grid);
#ifndef __sw_64__
  // Index one past the last neighbour-cell id; used as the end sentinel of
  // the first_*_cell offset arrays (readers use [first[d], first[d + 1])).
  const int end_did = hdcell(grid->nn, grid->nn, grid->nn, grid->nn + 1);
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    int ioff = get_offset_xyz<true>(grid, cx, cy, cz);

    // tag -> local slot for the atoms owned by this cell.
    mini_htab<tagint, int, CELL_CAP * 16 / 3> tagmap;
    for (int i = 0; i < icell->natom; i ++){
      tagmap[icell->tag[i]] = i;
    }
    mini_hset<tagint, MAX_CELL_GUEST * 4 / 3> guest_set;
    do_import_neighbors(ioff, tagmap, guest_set, grid, cx, cy, cz, utils::make_iseq<int, TOP_END>{});

    auto &special_cell = get<special_pair>().cells[ioff];

    int guest_offset = CELL_CAP; // next free guest slot (guests follow owned atoms)
    int nexcl = 0, nscal = 0;
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      icell->first_guest_cell[did] = guest_offset - CELL_CAP;
      icell->first_excl_cell[did] = nexcl;
      icell->first_scal_cell[did] = nscal;

      // tag -> index for the atoms of this neighbour cell.
      mini_htab<tagint, int, CELL_CAP * 4 / 3> jtab;
      for (int j = 0; j < jcell->natom; j ++) {
        jtab[jcell->tag[j]] = j;
      }

      // Special pairs (exclusions / scaled 1-4 style pairs) whose partner
      // atom lives in jcell are stored as (local i, index in jcell) pairs.
      for (auto &sp : special_cell.by_tag){
        if (jtab.contains(sp.id[1])) {
          int i = tagmap[sp.id[0]];
          int j = jtab[sp.id[1]];
          if (sp.pid == special_pair::EXCL){
            icell->excl_id[nexcl][0] = i;
            icell->excl_id[nexcl][1] = j;
            nexcl ++;
          } else if (sp.pid == special_pair::SCAL) {
            // NOTE(review): this block is only compiled when __sw_64__ is not
            // defined; confirm whether __sw__ can actually be set here.
#ifdef __sw__
            // Keep only one ordering of each intra-cell scaled pair.
            if (icell == jcell && i > j) continue;
#endif
            icell->scal_id[nscal][0] = i;
            icell->scal_id[nscal][1] = j;
            nscal ++;
          }
        }
      }
#ifndef __sw__
      // Sort this neighbour's exclusion / scaled-pair range by the jcell
      // index. excl_id / scal_id rows are viewed as {i, j} pairs for sorting.
      struct i2 {
        int i, j;
      };
      i2 *tmp_excl = (i2*)icell->excl_id;
      std::sort(tmp_excl + icell->first_excl_cell[did], tmp_excl + nexcl, [](const i2 &e1, const i2 &e2) -> bool {return e1.j < e2.j;});
      i2 *tmp_scal = (i2*)icell->scal_id;
      std::sort(tmp_scal + icell->first_scal_cell[did], tmp_scal + nscal, [](const i2 &e1, const i2 &e2) -> bool {return e1.j < e2.j;});
#endif
      // A cell's own atoms are never copied in as guests.
      if (icell == jcell) continue;
      // NOTE(review): no capacity check on guest_offset; the guest arrays are
      // presumably sized for MAX_CELL_GUEST entries -- confirm overflow can't occur.
      for (int j = 0; j < jcell->natom; j ++){
        if (guest_set.contains(jcell->tag[j])) {
          icell->t[guest_offset] = jcell->t[j];
          icell->rmass[guest_offset] = jcell->rmass[j];
          icell->tag[guest_offset] = jcell->tag[j];
          icell->guest_id[guest_offset - CELL_CAP] = j;
          guest_offset ++;
        }
      }
    }

    // Close the per-neighbour offset arrays with the end sentinel.
    icell->nguest = guest_offset - CELL_CAP;
    icell->first_guest_cell[end_did] = guest_offset - CELL_CAP;
    icell->first_excl_cell[end_did] = nexcl;
    icell->first_scal_cell[end_did] = nscal;

    // Extend the tag map with the guest slots before relinking.
    for (int i = 0; i < icell->nguest; i ++){
      tagmap[icell->tag[CELL_CAP + i]] = CELL_CAP + i;
    }
    relink(ioff, tagmap);

    // Rebuild per-atom rigid records: reset everything, then fill the owners
    // listed in this cell's rigid entries (id[0] = owner, id[1..3] = members).
    auto &rigid_cell = get<rigid_param>().cells[ioff];
    for (int i = 0; i < icell->natom; i ++){
      icell->rigid[i].type = RIGID_NONE;
    }
    for (int i = 0; i < rigid_cell.by_id.cnt; i ++){
      auto &ent = rigid_cell.by_id.entries[i];
      int owner = ent.id[0];
      auto &param = get<rigid_param>().param[ent.pid];
      auto &rig = icell->rigid[owner];

      rig.id[0] = ent.id[1];
      rig.id[1] = ent.id[2];
      rig.id[2] = ent.id[3];
      rig.type = param.type - 1;
      rig.r0[0] = param.r0[0];
      rig.r0[1] = param.r0[1];
      rig.r0[2] = param.r0[2];
    }
  }
#else
  topology_grids_import_sw(grid);
#endif
  timer_stop(UPDATE);
}

/**
 * Host-side evaluation of selected listed interactions for every local cell.
 *
 * For each interaction kind flagged in present[], the kind's per-cell grid
 * entry calc() is called with the cell's force and position arrays and the
 * kind's parameter table. Its return value is accumulated into stat:
 *   TOP_HARMONIC_BOND     -> stat->ebond
 *   TOP_UREY_BRADLY       -> stat->eangle
 *   TOP_COSINE_TORSION    -> stat->etori
 *   TOP_HARMONIC_IMPROPER -> stat->eimpr
 * Only these four kinds are evaluated here.
 *
 * @param stat accumulator receiving the per-kind sums.
 * @param grid cell grid providing the local cells' x and f arrays.
 */
void topology_grids::compute(mdstat_t *stat, cellgrid_t *grid){
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    const int coff = get_offset_xyz<true>(grid, cx, cy, cz);

    if (present[TOP_HARMONIC_BOND]) {
      auto &&bonds = get<TOP_HARMONIC_BOND>();
      stat->ebond += bonds.cells[coff].calc(icell->f, icell->x, bonds.param);
    }
    if (present[TOP_UREY_BRADLY]) {
      auto &&angles = get<TOP_UREY_BRADLY>();
      stat->eangle += angles.cells[coff].calc(icell->f, icell->x, angles.param);
    }
    if (present[TOP_COSINE_TORSION]) {
      auto &&torsions = get<TOP_COSINE_TORSION>();
      stat->etori += torsions.cells[coff].calc(icell->f, icell->x, torsions.param);
    }
    if (present[TOP_HARMONIC_IMPROPER]) {
      auto &&impropers = get<TOP_HARMONIC_IMPROPER>();
      stat->eimpr += impropers.cells[coff].calc(icell->f, icell->x, impropers.param);
    }
  }
}
/**
 * Refresh the guest slots of every local cell.
 *
 * For each neighbour cell jcell of icell, guests in the range
 * [first_guest_cell[d], first_guest_cell[d + 1]) get their position copied
 * from jcell. The position is shifted by (jcell->basis - icell->basis) so it is
 * expressed relative to icell's basis. The guest's force slot is cleared to 0.
 * Guest k lives at slot CELL_CAP + k and mirrors jcell atom guest_id[k].
 *
 * @param grid cell grid whose local cells are updated in place.
 */
void gather_guest_x(cellgrid_t *grid){
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      const int dc = hdcell(grid->nn, dx, dy, dz);
      const int gbeg = icell->first_guest_cell[dc];
      const int gend = icell->first_guest_cell[dc + 1];
      for (int g = gbeg; g < gend; g ++){
        const int slot = CELL_CAP + g;
        const int src = icell->guest_id[g];
        icell->x[slot] = jcell->x[src] + (jcell->basis - icell->basis);
        icell->f[slot] = 0;
      }
    }
  }
}

// Offloaded listed-force kernels; only declared here, defined elsewhere.
// NOTE(review): gather_guest_x_sw is not called from this file.
extern void gather_guest_x_sw (cellgrid_t *grid);
extern void compute_listed_forces_sw(mdstat_t *stat, cellgrid_t *grid);
// Timers for the gather / scatter / compute stages of compute_listed_forces().
DEF_TIMER (GATHER, "gather guest x")
DEF_TIMER (SCATTER, "scatter guest f")
DEF_TIMER (COMPUTE_BONDED, "compute bonded")
/**
 * Driver for listed (bonded) force evaluation, split into timed stages.
 *
 *  1. GATHER: on non-__sw__ builds, gather_guest_x() refreshes guest
 *     positions and clears guest forces. On __sw__ builds nothing is
 *     gathered here; the gather_guest_x_sw() call is disabled.
 *  2. COMPUTE_BONDED: compute_listed_forces_sw() is called on every build.
 *     (The host path topology_grids::compute() is not used.)
 *  3. SCATTER: currently empty.
 *
 * NOTE(review): the guest-force scatter is disabled. It used to add
 * icell->f[CELL_CAP + i] back into jcell->f[icell->guest_id[i]] over each
 * neighbour's first_guest_cell range. The SCATTER timer therefore measures no
 * work. Confirm that compute_listed_forces_sw() performs the scatter itself.
 *
 * @param grid  cell grid holding positions/forces.
 * @param param unused.
 * @param stat  accumulator for the energy terms.
 */
void compute_listed_forces(cellgrid_t *grid, void *param, mdstat_t *stat){
  (void)param; // not used by this driver

  timer_start(GATHER);
#ifndef __sw__
  gather_guest_x(grid);
#endif
  timer_stop(GATHER);

  timer_start(COMPUTE_BONDED);
  compute_listed_forces_sw(stat, grid);
  timer_stop(COMPUTE_BONDED);

  // Scatter stage intentionally empty (see NOTE above).
  timer_start(SCATTER);
  timer_stop(SCATTER);
}

